{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import time\n",
    "\n",
    "import gym\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Small fully connected network: Linear -> ReLU -> Linear, with 10 hidden units.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dims, output_dims):\n",
    "        \"\"\"Create the two linear layers.\n",
    "\n",
    "        Args:\n",
    "            input_dims: Number of input features of the first linear layer.\n",
    "            output_dims: Number of outputs of the second linear layer.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        hidden_dims = 10\n",
    "        self.fc1 = torch.nn.Linear(input_dims, hidden_dims)\n",
    "        self.relu1 = torch.nn.ReLU()\n",
    "        self.fc2 = torch.nn.Linear(hidden_dims, output_dims)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the second layer's raw outputs for input `x` (no final activation).\"\"\"\n",
    "        hidden = self.fc1(x)\n",
    "        activated = self.relu1(hidden)\n",
    "        return self.fc2(activated)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "343f4162277aa3df"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def choose_action(actor_net, state):\n",
    "    \"\"\"Sample a discrete action from the policy produced by `actor_net`.\n",
    "\n",
    "    Args:\n",
    "        actor_net: Callable that receives the observation as a float32 tensor\n",
    "            with an added leading batch dimension and returns action scores\n",
    "            whose dimension 1 indexes the actions.\n",
    "        state: A single observation (sequence, NumPy array or tensor).\n",
    "\n",
    "    Returns:\n",
    "        int: Index of the sampled action.\n",
    "    \"\"\"\n",
    "    # as_tensor accepts sequences, arrays and tensors alike; wrapping an existing\n",
    "    # tensor in torch.tensor() copies it again and emits a UserWarning.\n",
    "    state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        action_scores = actor_net(state)\n",
    "    actions_probability = torch.softmax(action_scores, dim=1)\n",
    "    action = torch.distributions.Categorical(actions_probability).sample()\n",
    "    return action.item()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "749dbcdbb80237c3"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def learn_critic(critic_net, critic_optimizer, state, reward, new_state, gamma):\n",
    "    \"\"\"Perform one squared-TD-error update of the critic.\n",
    "\n",
    "    Args:\n",
    "        critic_net: Module called on the batched state tensors to estimate their value.\n",
    "        critic_optimizer: Optimizer that is zeroed and stepped exactly once.\n",
    "        state: Observation before the transition.\n",
    "        reward: Reward received for the transition.\n",
    "        new_state: Observation after the transition.\n",
    "        gamma: Factor applied to the value estimate of `new_state`.\n",
    "\n",
    "    Returns:\n",
    "        float: gamma * V(new_state) + reward - V(state), computed before the optimizer step.\n",
    "    \"\"\"\n",
    "    def _to_batch(value):\n",
    "        # float32 tensor with a leading dimension of size 1\n",
    "        return torch.tensor(value, dtype=torch.float32).unsqueeze(0)\n",
    "\n",
    "    state_batch = _to_batch(state)\n",
    "    next_state_batch = _to_batch(new_state)\n",
    "    reward_batch = _to_batch(reward)\n",
    "\n",
    "    # The bootstrap estimate is evaluated without gradient tracking\n",
    "    with torch.no_grad():\n",
    "        next_value = critic_net(next_state_batch)\n",
    "    current_value = critic_net(state_batch)\n",
    "\n",
    "    td_error = gamma * next_value + reward_batch - current_value\n",
    "    squared_error = td_error ** 2\n",
    "    critic_optimizer.zero_grad()\n",
    "    squared_error.backward()\n",
    "    critic_optimizer.step()\n",
    "    return td_error.item()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4fd74f07de14e2ee"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def learn_actor(actor_net, actor_optimizer, state, action, td_error):\n",
    "    \"\"\"Apply one policy-gradient step weighted by the TD error.\n",
    "\n",
    "    The loss is -log pi(action | state) * td_error, where pi is the softmax\n",
    "    over the outputs of `actor_net`.\n",
    "\n",
    "    Args:\n",
    "        actor_net: Module mapping the observation tensor to action scores.\n",
    "        actor_optimizer: Optimizer that is zeroed and stepped exactly once.\n",
    "        state: Observation in which `action` was taken (passed unbatched).\n",
    "        action: Index used to select the log-probability entry.\n",
    "        td_error: Scalar weight applied to the log-probability.\n",
    "    \"\"\"\n",
    "    action_scores = actor_net(torch.as_tensor(state, dtype=torch.float32))\n",
    "    # log_softmax with an explicit dim replaces log(Softmax()(x)): it avoids the\n",
    "    # implicit-dim deprecation warning and the log(0) = -inf that produces NaN\n",
    "    # gradients when an action's probability underflows.\n",
    "    log_prob = torch.log_softmax(action_scores, dim=-1)\n",
    "    loss = -log_prob[action] * td_error\n",
    "    actor_optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    actor_optimizer.step()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8b1ab297d33e253"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Start every run from an empty TensorBoard log directory.\n",
    "log_dir = './runs'\n",
    "log_dir_existed = os.path.exists(log_dir)\n",
    "if log_dir_existed:\n",
    "    try:\n",
    "        shutil.rmtree(log_dir)\n",
    "        print(f'文件夹 {log_dir} 已成功删除。')\n",
    "    except OSError as error:\n",
    "        print(f'删除文件夹 {log_dir} 失败: {error}')\n",
    "# Recreate the directory on both paths so it always exists after this cell;\n",
    "# exist_ok covers the case where the deletion above failed.\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "if not log_dir_existed:\n",
    "    print(f'文件夹 {log_dir} 不存在，已创建文件夹 {log_dir}。')"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "65145df6748ea13c",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Online actor-critic training on CartPole-v1: one critic and one actor update per step.\n",
    "summary_writer = SummaryWriter(log_dir=log_dir)\n",
    "env = gym.make('CartPole-v1')\n",
    "gamma = 0.98\n",
    "n_actions = env.action_space.n\n",
    "n_features = env.observation_space.shape[0]\n",
    "actor_net = Net(n_features, n_actions)\n",
    "actor_optimizer = torch.optim.Adam(actor_net.parameters(), lr=1e-3)\n",
    "critic_net = Net(n_features, 1)\n",
    "critic_optimizer = torch.optim.Adam(critic_net.parameters(), lr=1e-3)\n",
    "\n",
    "episodes = 5000\n",
    "steps = 5000\n",
    "try:\n",
    "    for episode in range(episodes):\n",
    "        start_time = time.time()\n",
    "        state, _ = env.reset()\n",
    "        step = 0\n",
    "        # `<` rather than `<=` so an episode performs at most `steps` updates\n",
    "        while step < steps:\n",
    "            _action = choose_action(actor_net, state)\n",
    "            # Only the termination flag ends an episode here; the truncation flag\n",
    "            # (fourth return value) is discarded.\n",
    "            new_state, reward, done, _, _ = env.step(_action)\n",
    "            if done:\n",
    "                # Penalise the step that ended the episode\n",
    "                reward = -20\n",
    "            # A terminal state has no future return, so its value estimate must not\n",
    "            # be bootstrapped: use a discount of 0 for the final transition.\n",
    "            step_gamma = 0.0 if done else gamma\n",
    "            td_error = learn_critic(critic_net, critic_optimizer, state, reward, new_state, step_gamma)\n",
    "            learn_actor(actor_net, actor_optimizer, state, _action, td_error)\n",
    "            step += 1\n",
    "            state = new_state\n",
    "            if done:\n",
    "                summary_writer.add_scalar('step', step, episode)\n",
    "                print('Episode: {}/{}  | Step: {}  | Running Time: {:.4f}'.format(episode,\n",
    "                                                                                  episodes,\n",
    "                                                                                  step,\n",
    "                                                                                  time.time() - start_time))\n",
    "                break\n",
    "        # Stop training once an episode lasted at least 1000 steps\n",
    "        if step >= 1000:\n",
    "            break\n",
    "finally:\n",
    "    # Flush pending TensorBoard events and release the environment even on interruption\n",
    "    summary_writer.close()\n",
    "    env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6b0d41c2febcf1ff"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Roll out one rendered episode with the trained actor; actions are sampled by choose_action.\n",
    "env = gym.make('CartPole-v1', render_mode='human')\n",
    "step = 0\n",
    "total_reward = 0\n",
    "try:\n",
    "    state, _ = env.reset()\n",
    "    while True:\n",
    "        # choose_action converts the observation to a tensor itself\n",
    "        a = choose_action(actor_net, state)\n",
    "        new_state, reward, done, truncated, _ = env.step(a)\n",
    "        step += 1\n",
    "        total_reward += reward\n",
    "        state = new_state\n",
    "        # The episode is over on termination or on time-limit truncation;\n",
    "        # stepping further without reset() is outside the Gym API contract.\n",
    "        if done or truncated:\n",
    "            break\n",
    "    print('step:{}'.format(step))\n",
    "    print('total_reward:{}'.format(total_reward))\n",
    "finally:\n",
    "    # Close the render window even if the rollout is interrupted\n",
    "    env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5af05addcf744d3b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Close the environment currently bound to `env`.\n",
    "env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "52967f6d08a58951"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
